--- external/eg3d/training/dataset.py	2023-04-06 03:55:39.232777639 +0000
+++ external_reference/eg3d/training/dataset.py	2023-04-06 03:41:04.542718663 +0000
@@ -23,6 +23,8 @@
 except ImportError:
     pyspng = None
 
+from utils import midas
+
 #----------------------------------------------------------------------------
 
 class Dataset(torch.utils.data.Dataset):
@@ -91,7 +93,7 @@
         image = self._load_raw_image(self._raw_idx[idx])
         assert isinstance(image, np.ndarray)
         assert list(image.shape) == self.image_shape
-        assert image.dtype == np.uint8
+        # assert image.dtype == np.uint8 # depth is float values
         if self._xflip[idx]:
             assert image.ndim == 3 # CHW
             image = image[:, :, ::-1]
@@ -163,14 +165,33 @@
     def __init__(self,
         path,                   # Path to directory or zip.
         resolution      = None, # Ensure specific resolution, None = highest available.
+        depth_scale     = 16,   # scale factor for depth
+        depth_clip      = 20,   # clip all depths above this value
+        white_sky       = False, # mask sky with white pixels if true
         **super_kwargs,         # Additional arguments for the Dataset base class.
     ):
         self._path = path
         self._zipfile = None
+        self.depth_clip = depth_clip
+        self.depth_scale = depth_scale
+        self.white_sky = white_sky
 
         if os.path.isdir(self._path):
             self._type = 'dir'
-            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
+            # NOTE: the file listing is cached in cache.txt inside the dataset
+            # directory; it also lists disp/seg files, which _image_fnames filters out.
+            if os.path.isfile(self._path + '/cache.txt'):
+                with open(self._path + '/cache.txt') as cache:
+                    self._all_fnames = set([line.strip() for line in cache])
+            else:
+                print("Walking dataset...")
+                self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self._path)
+                                    for root, _dirs, files in
+                                    os.walk(self._path, followlinks=True) for fname in files]
+                with open(self._path + '/cache.txt', 'w') as cache:
+                    [cache.write("%s\n" % fname) for fname in self._all_fnames]
+                self._all_fnames = set(self._all_fnames)
+                print("done walking")
         elif self._file_ext(self._path) == '.zip':
             self._type = 'zip'
             self._all_fnames = set(self._get_zipfile().namelist())
@@ -216,16 +237,63 @@
         return dict(super().__getstate__(), _zipfile=None)
 
     def _load_raw_image(self, raw_idx):
+        """Load one sample as a CHW array with a disparity channel appended.
+
+        The disparity map (.pfm) and sky mask (.npz, key 'sky_mask') are found
+        by substituting 'png'/'img' in the image path; any '_mirror' marker is
+        dropped, and both maps are flipped left-right when fname has 'mirror'.
+
+        Returns:
+            CHW array: the image channels followed by one channel of
+            rescaled disparity in [1/depth_scale, 1], zeroed where the sky
+            mask is 0. Float-valued because of the disparity channel.
+        """
         fname = self._image_fnames[raw_idx]
         with self._open_file(fname) as f:
-            if pyspng is not None and self._file_ext(fname) == '.png':
-                image = pyspng.load(f.read())
-            else:
-                image = np.array(PIL.Image.open(f))
+            image = np.array(PIL.Image.open(f))
         if image.ndim == 2:
             image = image[:, :, np.newaxis] # HW => HWC
         image = image.transpose(2, 0, 1) # HWC => CHW
-        return image
+
+        # Side files are located by plain substring replacement on the full
+        # path, so 'png'/'img' anywhere in it (not just the suffix) is rewritten.
+        full_path = os.path.join(self._path, fname)
+
+        # Disparity: normalize to [0, 1] between the 1st and 99th percentiles.
+        depth_path = (full_path.replace('png', 'pfm')
+                      .replace('img', 'disp')
+                      .replace('_mirror', ''))
+        disp, _scale = midas.read_pfm(depth_path)
+        disp = np.array(disp)
+        dmmin = np.percentile(disp, 1)
+        dmmax = np.percentile(disp, 99)
+        disp = np.clip((disp - dmmin) / (dmmax - dmmin + 1e-6), 0., 1.)
+
+        # Sky mask: 0 marks sky pixels (see the white_sky handling below).
+        sky_path = (full_path.replace('png', 'npz')
+                    .replace('img', 'seg')
+                    .replace('_mirror', ''))
+        sky_mask = np.load(sky_path)['sky_mask']
+
+        if 'mirror' in fname:
+            disp = np.fliplr(disp)
+            sky_mask = np.fliplr(sky_mask)
+
+        # Remap depth [0, depth_clip-1] linearly onto [0, depth_scale-1].
+        disp_clipped = np.clip(disp, 1/self.depth_clip, 1) # range: [1/clip, 1]
+        pseudo_depth = 1/disp_clipped - 1 # range: [0, clip-1]
+        max_depth = self.depth_clip - 1
+        scaled_depth = pseudo_depth / max_depth * (self.depth_scale - 1) # range: [0, depth_scale-1]
+        scaled_disp = 1/(scaled_depth+1) # range: [1/depth_scale, 1]
+
+        # Zero the disparity wherever the sky mask is 0.
+        scaled_disp = scaled_disp * sky_mask
+
+        if self.white_sky:
+            sky_color = np.array([255, 255, 255]).reshape(-1, 1, 1)
+            image = (image * sky_mask[None] + sky_color * (1-sky_mask[None]))
+
+        return np.concatenate([image, scaled_disp[None]], axis=0)
 
     def _load_raw_labels(self):
         fname = 'dataset.json'
